Improve MPT fp8 #1256
Conversation
python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --max_input_tokens 128 --max_new_tokens 128 --batch_size 256 --bf16
python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --max_input_tokens 128 --max_new_tokens 128 --batch_size 256 --bf16 --use_flash_attention
flash_attention_recompute: Optional[bool] = False,
):
    """
    Copied from MptAttention.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
At least part of the code looks like it was copied from a newer version than v4.32.0; could you verify and update this comment?
This line is original (line 123).
Still, please update it to the latest version, as we are copying code from the latest release.
merge MptAttention forward r4.44.1
Force-pushed from 30d8302 to f7704e4
Force-pushed from 5f901df to b3f729f
Add Softmax and FusedSDPA
Update GaudiMptAttention forward to r4.44.1 base
Co-authored-by: Thanaji Rao Thakkalapelli <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@atakaha, please use make style to fix the code format issue.
…etes example (huggingface#1286) Signed-off-by: dmsuehir <[email protected]>
@libinta, please have somebody review this PR. Thanks.
attn_weights = None
else:
    attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
You should also use the fp8 matmul kernel. Please check the Falcon code below.
https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/falcon/modeling_falcon.py#L108
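For reference, the linked Falcon code follows a pattern where the bare torch.matmul is wrapped in a small nn.Module so that INC can patch it with an fp8 matmul kernel. A minimal sketch of that pattern, with illustrative attribute names like matmul_qk/matmul_av that are assumptions rather than this PR's code:

import torch

class Matmul(torch.nn.Module):
    # Thin wrapper so the quantization toolkit (INC) can swap in an fp8 matmul kernel.
    def forward(self, x, y):
        return torch.matmul(x, y)

# Illustrative use inside an attention module (attribute names are assumptions):
#   self.matmul_qk = Matmul()   # created in __init__
#   self.matmul_av = Matmul()
#   ...
#   attention_scores = self.matmul_qk(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
#   context_states = self.matmul_av(attention_probs, value_states)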
I tested it and throughput slowed down. The trace showed that when MPT uses it, it doesn't run in fp8, whereas the bare torch.matmul does use fp8. That's why MPT doesn't use the fp8 matmul kernel this time.
That's weird. I'm wondering if fp8_softmax is actually forcing the matmul to also run in fp8. Do you see good accuracy with fp8 softmax?
class Matmul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)
I compared FP8 profile data between GPT-J, Mistral, and MPT:
- GPT-J and Mistral spend <70% of the time in MME, but MPT spends ~30%.
- GPT-J and Mistral call the index_copy_fwd_hf8 and cast_bf16_to_hf8 kernels, but those kernel calls don't appear in MPT. When matmul_qk and matmul_qv are added to the blocklist of maxabs_quant.json for MPT, it calls the same kernels as GPT-J/Mistral.
This is the reason we didn't add this.
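For context, a hypothetical sketch of what blocklisting those ops could look like; the exact schema of maxabs_quant.json and the surrounding fields are assumptions, not taken from this PR:

# Hypothetical sketch: the config schema below is assumed, not copied from this PR.
import json

quant_config = {
    "method": "HOOKS",
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    # Assumed blocklist section: keep these wrapped matmuls out of fp8 quantization.
    "blocklist": {"names": ["matmul_qk", "matmul_qv"]},
    "dump_stats_path": "./hqt_output/measure",
}

# Hypothetical file name, for illustration only.
with open("maxabs_quant_mpt.json", "w") as f:
    json.dump(quant_config, f, indent=4)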
Also, can you make sure your change doesn't break anything for the training case, since the model file is used for both?
if use_flash_attention and FusedSDPA:
    import habana_frameworks.torch.hpu as ht

    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
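For reference, a minimal sketch of what the call under this context manager could look like; the FusedSDPA.apply argument order here is an assumption based on other optimum-habana ports, not this PR's exact code:

# Sketch only; assumes FusedSDPA was imported earlier via
#   from habana_frameworks.torch.hpex.kernels import FusedSDPA
# and that apply() follows the (q, k, v, mask, dropout_p, is_causal, scale) convention
# used by other optimum-habana ports.
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
    attn_output = FusedSDPA.apply(
        query_states,    # (batch, num_heads, q_len, head_dim)
        key_states,      # (batch, num_heads, kv_len, head_dim)
        value_states,    # (batch, num_heads, kv_len, head_dim)
        attention_mask,  # additive attention mask, or None
        0.0,             # dropout probability
        False,           # is_causal: the mask is passed explicitly here
        None,            # scale: None falls back to 1/sqrt(head_dim)
    )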
Can you also check other models to see whether this enable_recompute should be set based on fp8/bf16 and q_len?
If enable_recompute is set the same way as in other models, FP8 throughput drops to half. In the trace, softmax_stage1_fwd_f32 appears and spends a lot of time; that kernel is not present in the enable_recompute = False case.
what's your fp8 command with flash_attention?
Did you test any case that enables causal_mask and enable_recompute?
How about a longer prompt? Usually causal_mask shows better perf for long prompts.
what's your fp8 command with flash_attention?
The command line is:
QUANT_CONFIG=./quantization_config/maxabs_quant.json \
python run_generation.py \
--model_name_or_path mosaicml/mpt-7b \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--max_input_tokens 128 \
--max_new_tokens 128 \
--batch_size 128 \
--bf16 \
--use_flash_attention
Did you test any case that enables causal_mask and enable_recompute? How about a longer prompt? Usually causal_mask shows better perf for long prompts.
No, I haven't tested these cases.
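For background, a small illustrative helper (not from this PR or any particular model file) showing how a causal-mask fast path is typically chosen for prompt versus decode steps:

# Illustrative only: the function name and policy are assumptions, not optimum-habana API.
def choose_sdpa_mask(q_len, attention_mask, use_causal_mask):
    # Return (mask, is_causal) for a fused SDPA call.
    if q_len > 1 and use_causal_mask:
        # Prompt phase: let the kernel apply the causal mask itself,
        # which is usually faster for long prompts.
        return None, True
    # Decode phase (q_len == 1) or explicit masking: pass the mask through.
    return attention_mask, False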
    super().__init__()

def forward(self, x, dim=None, invAttnHead=None):
    return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)
Since INC is enabled, please use torch.nn.functional.softmax, as it is a function supported by INC for quantization.
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html#supported-functions
Suggested change:
- return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)
+ return torch.nn.functional.softmax(x, dim)
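Put together, the wrapper module with this suggestion applied would look roughly like the sketch below; the class name follows the commit message ("Add Softmax") and the rest is assumed unchanged:

import torch

class Softmax(torch.nn.Module):
    # Softmax wrapped in a module so INC can quantize it to fp8.
    def __init__(self):
        super().__init__()

    def forward(self, x, dim=None, invAttnHead=None):
        # invAttnHead is kept for signature compatibility; using the standard softmax
        # lets INC substitute the fp8 variant during quantization.
        return torch.nn.functional.softmax(x, dim)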
Updated the code with your suggestion.
Looks good.
Could you investigate further, though, to see whether we need causal_mask and recompute enabled differently for long prompts, and submit a separate patch if needed?
The code quality check failed, please run make style.
Fixed the ruff error.
Sure, I will.
@regisss, please review this PR.
What does this PR do?
- Add Softmax and FusedSDPA
- Fix unnecessary args in the self._gradient_checkpointing_func() call.